Participation name on the Algonauts leaderboard: "LukeK" (Score: 50.5408)
We ran and tested all models locally on a single AMD GPU. Any additional data we used on top of the challenge data, as well as this notebook, can be found in our GitHub repository (e.g., a CSV file encoding the COCO ID, NSD ID, and caption matches).
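To illustrate the structure of such a matching file, here is a minimal sketch of loading it with pandas. The column names and rows below are hypothetical placeholders, not the actual contents of the repository's CSV:

```python
import io
import pandas as pd

# Hypothetical excerpt of the matching file (column names and values are
# assumptions for illustration; see the GitHub repository for the real file).
csv_text = """nsd_id,coco_id,caption
0,100001,"placeholder caption for image one"
1,100002,"placeholder caption for image two"
"""
matches = pd.read_csv(io.StringIO(csv_text))

# One row per stimulus: NSD ID, its COCO ID, and the matched caption
assert list(matches.columns) == ["nsd_id", "coco_id", "caption"]
assert len(matches) == 2
```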
Recent advances in machine learning have enabled the development of computational models that can predict brain responses to visual stimuli. In this study, we investigate the effectiveness of CLIP (Contrastive Language-Image Pre-Training), a state-of-the-art neural network model developed by OpenAI, at predicting whole-brain activation in response to image stimuli in the Algonauts 2023 Challenge. The data consist of fMRI recordings from participants viewing complex natural visual scenes. We proposed three CLIP-based solutions: a Text model, a Vision model, and a Combination model that stacks features from the Text and Vision models. We hypothesized that the Combination model would provide the best approximation of the brain responses, because its cross-modal understanding should allow it to effectively recognize objects and infer relationships between them. The results showed that the Text model was the worst-performing model at predicting brain activation, while the Vision model outperformed the other models; unexpectedly, the Combination model performed equal to or worse than the Vision model. As our final solution to the Algonauts 2023 Challenge, we propose combining layers 4 and 8 of the CLIP Vision model. Our findings suggest that the CLIP Vision model provides a more effective approach for predicting brain activation in response to visual stimuli than either the Text model or the Combination model, particularly for the early layers.
The Algonauts 2023 Challenge was established to investigate how well current computational models predict brain responses to naturalistic stimuli. It is an open challenge intended to stimulate exchange between the biological and artificial intelligence fields. Participants build computational models to predict brain responses to images for which the recorded brain data was held out. In this report we describe the computational model we propose for predicting this brain data.
We propose the use of CLIP (Contrastive Language-Image Pre-Training; Radford et al., 2021), a state-of-the-art neural network model developed by OpenAI that can understand the relationship between images and text. It is based on a Transformer architecture, similar to those used in language models like GPT-3, and is pre-trained on a large amount of text and images in an unsupervised way. What sets CLIP apart from other models is its ability to learn cross-modal representations of images and text, allowing it to model the relationship between the two modalities. This is achieved by using a Vision Transformer, which splits an image into patches that are linearly embedded and then fed into a standard Transformer encoder (for details, see Dosovitskiy et al., 2021). Additionally, its contrastive learning objective allows the model to associate text and images that refer to the same concepts. For example, the model learns that an image of a dog and a caption that says "a dog running in a park" both relate to the concept of a dog.
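The contrastive objective described above can be sketched in a few lines: matching image–text pairs should have high cosine similarity, non-matching pairs low similarity. The toy embeddings below are random stand-ins, not real CLIP outputs:

```python
import numpy as np

def contrastive_logits(img_emb, txt_emb, temperature=0.07):
    """Cosine-similarity logits between L2-normalised image and text
    embeddings, as in CLIP's contrastive objective (a sketch, not the
    actual implementation)."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    return img @ txt.T / temperature

# Toy embeddings: item i of each modality refers to the same concept,
# so matched pairs are nearly parallel vectors.
rng = np.random.default_rng(0)
concepts = rng.normal(size=(3, 8))
img_emb = concepts + 0.01 * rng.normal(size=(3, 8))
txt_emb = concepts + 0.01 * rng.normal(size=(3, 8))

logits = contrastive_logits(img_emb, txt_emb)
# Each image is most similar to its own caption (the diagonal wins)
assert (logits.argmax(axis=1) == np.arange(3)).all()
```

Training pushes the diagonal of this logits matrix up and the off-diagonal entries down, which is what lets CLIP match images to captions it has never seen.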
The Algonauts 2023 Challenge data come from the Natural Scenes Dataset (NSD; Allen et al., 2022), a massive 8-subject dataset of 7T fMRI responses to images of natural scenes taken from the COCO database (Lin et al., 2014). The COCO dataset contains a large number of images labeled with object categories, captions, and annotations. As previously mentioned, the key advantage of CLIP is its ability to learn cross-modal representations of images and text, allowing it to model the relationship between the two modalities (Radford et al., 2021). This is particularly relevant for the COCO dataset, which contains both visual and textual information, such as image captions and object annotations. By leveraging this cross-modal understanding, CLIP can more effectively recognize concepts and infer relationships between them. Given that the Challenge evaluation score is computed over the whole visual brain, we hypothesized that a combined computational model, in which visual properties interact with more abstract semantic information, offers the best solution for this challenge. Moreover, because CLIP has been trained on a large amount of text and images in an unsupervised way, it can generalize to new and unseen data. This is crucial in the Algonauts 2023 Challenge, where the goal is to build a model that understands language and vision in a way that transfers to new contexts and tasks. CLIP's pre-training on large amounts of data also means that it can be fine-tuned on smaller datasets, like those used in the Algonauts Challenge, to further improve its performance.
Here, we propose a CLIP model that uses a ViT-B/32 Transformer architecture as an image encoder and a masked self-attention Transformer as a text encoder as a solution to the Algonauts 2023 Challenge. To this end, we created three models: (1) a CLIP Vision model, (2) a CLIP Text model, and (3) a combination of features from the CLIP Text and CLIP Vision models.
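Since our final solution combines features from individual encoder layers (layers 4 and 8 of the Vision model), the key mechanic is extracting intermediate-layer activations rather than only the final output. A minimal sketch of that idea, using forward hooks on a toy stack of linear layers as a stand-in for the real CLIP encoder:

```python
import torch
import torch.nn as nn

# Toy stand-in for a 12-layer transformer encoder (NOT the real CLIP model;
# we use plain linear layers so the sketch runs without any downloads).
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(12)])

features = {}

def save_output(idx):
    # Forward hooks receive (module, inputs, output) and let us capture
    # the activation of a specific layer as data flows through the stack.
    def hook(module, inputs, output):
        features[idx] = output.detach()
    return hook

# Register hooks on layers 4 and 8, mirroring the layer choice in our
# final Vision-model solution.
for idx in (4, 8):
    layers[idx].register_forward_hook(save_output(idx))

x = torch.randn(2, 16)          # batch of 2 toy "stimuli"
with torch.no_grad():
    for layer in layers:
        x = layer(x)

assert sorted(features) == [4, 8]
assert features[4].shape == (2, 16)
```

With the Hugging Face models defined below, the same result can be obtained more directly by calling the model with `output_hidden_states=True` and indexing into `outputs.hidden_states`.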
As recently shown by Marjieh et al. (2023) in the field of similarity judgments, stacking the representations of language-based models and models that do not rely on text (including image-based models) consistently provides the best approximation of human similarity judgments. We believe this insight carries over from similarity judgments to viewing complex natural visual scenes. Furthermore, they found that stacking the representations of the CLIP image and CLIP text models outperforms the individual models. Therefore, we hypothesized that a combination of features from the CLIP Text and CLIP Vision models provides the best approximation of brain responses recorded from participants viewing complex natural visual scenes.
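The "stacking" idea amounts to concatenating the two feature matrices along the feature axis before fitting the encoding model. A minimal sketch with random toy features and toy voxel responses (the dimensions are illustrative, not the real CLIP/NSD sizes):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(1)
n_stim = 100
vis_feat = rng.normal(size=(n_stim, 32))   # stand-in for CLIP Vision features
txt_feat = rng.normal(size=(n_stim, 32))   # stand-in for CLIP Text features
fmri = rng.normal(size=(n_stim, 5))        # toy voxel responses

# "Combination" model: stack modalities along the feature axis,
# then map the stacked features to voxel responses with linear regression.
combined = np.concatenate([vis_feat, txt_feat], axis=1)
assert combined.shape == (n_stim, 64)

reg = LinearRegression().fit(combined, fmri)
pred = reg.predict(combined)
assert pred.shape == fmri.shape
```

The individual Vision and Text models are the same pipeline with only one of the two feature matrices, so the three models differ solely in the design matrix fed to the regression.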
In the following sections we load the packages and helper functions used throughout this notebook to answer the research question.
Importing all required packages and models to run this notebook.
# Import required packages
import os
import glob
import time
import numpy as np
import pandas as pd
import torch
import torch_directml
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from transformers import AutoProcessor, CLIPTextModel, CLIPVisionModel, PreTrainedModel
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import IncrementalPCA, PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import euclidean_distances
from scipy.stats import pearsonr as corr
from scipy.stats import spearmanr
from scipy.spatial.distance import squareform
from tqdm import tqdm
from typing import List, Dict, Tuple
from nilearn import datasets
from nilearn import plotting
import warnings
warnings.filterwarnings("ignore") # to reduce unnecessary output
We used an AMD RX 5700XT GPU and an AMD Ryzen 5 3600XT CPU to run this code. To enable GPU support, we relied on torch_directml == 0.1.13.1.dev230413. If you want to run this notebook on a CUDA device instead, set AMD to False.
# Set up the compute device: CUDA if available, otherwise CPU;
# DirectML when running on an AMD GPU
global device
AMD = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if AMD:
device = torch_directml.device()
First we define the models we are going to use in this notebook. This is necessary, as we will use some of their functionality throughout the notebook. We downloaded the pretrained CLIP models and the CLIPProcessor from Hugging Face.
# Defining Models
global vis_model, txt_model, processor
vis_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
txt_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
vis_model.eval()
txt_model.eval()
(Loading the checkpoint prints an expected Hugging Face warning: the text-model weights of openai/clip-vit-base-patch32 are not used when initializing CLIPVisionModel, and the vision-model weights are not used when initializing CLIPTextModel. The full weight listings are omitted here.)
CLIPTextModel(
(text_model): CLIPTextTransformer(
(embeddings): CLIPTextEmbeddings(
(token_embedding): Embedding(49408, 512)
(position_embedding): Embedding(77, 512)
)
(encoder): CLIPEncoder(
(layers): ModuleList(
(0): CLIPEncoderLayer(
(self_attn): CLIPAttention(
(k_proj): Linear(in_features=512, out_features=512, bias=True)
(v_proj): Linear(in_features=512, out_features=512, bias=True)
(q_proj): Linear(in_features=512, out_features=512, bias=True)
(out_proj): Linear(in_features=512, out_features=512, bias=True)
)
(layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP(
(activation_fn): QuickGELUActivation()
(fc1): Linear(in_features=512, out_features=2048, bias=True)
(fc2): Linear(in_features=2048, out_features=512, bias=True)
)
(layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(1-11): 11 x CLIPEncoderLayer(...)  [layers (1) through (11) are identical in structure to layer (0); repeated printout condensed]
)
)
(final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
)
The following section defines a number of helper classes and methods we use throughout the notebook to prepare, modify, and fit the data. Explanations for each sub-section are provided.
Some utility functions to reduce code duplication and streamline across-subject operations.
The function below plots the neural or feature RDMs at different zoom levels. Additionally, we define a function that computes the correlation between neural and feature RDMs.
def plot_rdm(neural_rdms: dict = None, layer_features: dict = None, zoom: list = None):
"""Plots neural or feature RDMs for given subjects or models
Args:
neural_rdms: dictionary containing the neural RDMs for the left and right hemisphere
layer_features: dictionary containing the layer features
zoom: list containing the range of the x and y axis"""
rdm_n, layers, size = (len(layer_features), list(layer_features.keys()), 25) if layer_features is not None \
else (len(neural_rdms), list(neural_rdms.keys()), 10)
fig, ax = plt.subplots(1, rdm_n, figsize=(size, size))
plt.subplots_adjust(wspace=0.2)
for i, layer in enumerate(layers):
if layer_features is not None:
rdm = euclidean_distances(torch.tensor(layer_features[layer]).flatten(1).numpy())
ax[i].set_xlabel(f"Correlation with neural RDM: {corr_rdm(neural_rdms, rdm):.2f}")
else:
rdm = neural_rdms[layer]
ax[i].set_xlabel("Trials", fontsize=8)
ax[i].set_title(layer, fontsize=10)
ax[i].imshow(rdm)
if zoom is not None:
ax[i].set_xlim(zoom)
ax[i].set_ylim([zoom[-1], zoom[0]])
plt.show()
def corr_rdm(neural_rdm: dict = None, layer_rdm: np.ndarray = None):
"""Computes the correlation between neural and feature RDMs
Args:
neural_rdm: dictionary containing the neural RDMs for the left and right hemisphere
layer_rdm: array containing the layer RDM"""
layer_rdv = squareform(layer_rdm.round(5))
neural_rdv = [squareform(n_rdm.round(5)) for n_rdm in neural_rdm.values()]
return np.mean([spearmanr(n_rdv, layer_rdv)[0] for n_rdv in neural_rdv])
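As a minimal, self-contained illustration of the steps inside `corr_rdm` (synthetic data; SciPy's `pdist` is used here in place of sklearn's `euclidean_distances`, and the `.round(5)` mirrors the rounding above that keeps the matrix valid for `squareform`):

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
responses = rng.normal(size=(6, 20))   # 6 trials x 20 voxels of synthetic "neural" data
rdm = squareform(pdist(responses))     # 6x6 symmetric distance matrix with zero diagonal
rdv = squareform(rdm.round(5))         # condensed upper-triangle vector (15 values)
rho = spearmanr(rdv, rdv)[0]           # an RDM correlated with itself yields 1.0
```

Working on the condensed vectors avoids double-counting the symmetric entries and drops the uninformative zero diagonal before the rank correlation.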
The function below computes and prepares the neural RDMs from neural observations corresponding to the shared images (in the training data).
def neural_rdm_shared(subs: list = [f"subj0{i}" for i in range(1, 9)]):
"""Computes the shared neural RDMs for the specified subjects
Args:
subs: list of strings identifying the subjects"""
train_cap_file = pd.read_csv('data/algonauts_2023_caption_data.csv')
# Save all shared trials per subject
shared_img = []
for sub in subs:
dirs = Subject(sub)
dirs.load_image_paths()
img_match = [int(i[-9:-4]) for i in dirs.train_img_list]
sub_df = train_cap_file[(train_cap_file['subject'] == sub) & (train_cap_file['nsdId'].isin(img_match))].reset_index(drop=True)
shared_img.append(sub_df[(sub_df['n'] == 8)][['nsdId']])
# Counting the shared images that occur in all subjects training split and saving the IDs
counts_id = pd.concat(shared_img, axis=0, join="inner").groupby("nsdId").size()
keep_id = counts_id[counts_id == 8].index.tolist()
# Calculating the neural RDMs based on the retained IDs and corresponding indexes
neural_rdms_shared = []
for i, sub in enumerate(subs):
dirs = Subject(sub)
dirs.load_neural_data()
shared_idx = shared_img[i][shared_img[i]["nsdId"].isin(keep_id)].index.values
lh_fmri_shared, rh_fmri_shared = dirs.lh_fmri[shared_idx], dirs.rh_fmri[shared_idx]
neural_rdms_shared.append(np.stack([euclidean_distances(lh_fmri_shared), euclidean_distances(rh_fmri_shared)]))
return np.stack(neural_rdms_shared)
The following function returns the index of a given ROI class.
def roi_class_index(roi: str = "V1v"):
"""Returns the ROI class index for the specified ROI
Args:
roi: string indicating the ROI"""
if roi in ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"]:
roi_class = 0
elif roi in ["EBA", "FBA-1", "FBA-2", "mTL-bodies"]:
roi_class = 1
elif roi in ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"]:
roi_class = 2
elif roi in ["OPA", "PPA", "RSC"]:
roi_class = 3
elif roi in ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"]:
roi_class = 4
elif roi in ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]:
roi_class = 5
else:
roi_class = 6
return roi_class
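The branching above is effectively a lookup table mapping each ROI to its class index; a compact, standalone sketch of the same mapping (equivalent behavior, alternative formulation):

```python
ROI_CLASSES = {
    0: ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"],
    1: ["EBA", "FBA-1", "FBA-2", "mTL-bodies"],
    2: ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"],
    3: ["OPA", "PPA", "RSC"],
    4: ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"],
    5: ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"],
}

def roi_class_index(roi: str = "V1v") -> int:
    # Any ROI not listed (e.g. "all-vertices") falls through to class 6
    return next((c for c, rois in ROI_CLASSES.items() if roi in rois), 6)

print(roi_class_index("FFA-1"))  # 2
```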
The ImageDataset and TextDataset classes are used to create datasets for PyTorch dataloaders. The ImageDataset is used for the NSD images and the TextDataset for the captions of the corresponding COCO images.
class ImageDataset(Dataset):
"""Class to prepare the image data for the PyTorch DataLoader"""
def __init__(self, image_list, processor):
self.image_list = image_list
self.processor = processor
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
image = Image.open(self.image_list[idx])
image = self.processor(images=image, return_tensors="pt", padding=True)
return image["pixel_values"].squeeze()
class TextDataset(Dataset):
"""Class to prepare the text data for the PyTorch DataLoader"""
def __init__(self, text, max_length, processor):
self.text = processor(text=text, return_tensors="pt", padding="max_length", max_length=max_length)
def __len__(self):
return len(self.text["input_ids"])
def __getitem__(self, idx):
return self.text["input_ids"][idx]
The Subject class is initialized with a valid subject id (e.g., "subj01"). It stores all relevant paths and can load the data for the given subject.
class Subject:
"""Class to access all relevant data for a given subject"""
def __init__(self, subject="subj01"):
assert subject in ["subj01", "subj02", "subj03", "subj04", "subj05", "subj06", "subj07", "subj08",], "Invalid subject"
self.subject = subject
self.data_dir = "data/algonauts_2023_challenge_data"
self.training_images_dir = f"{self.data_dir}/{subject}/training_split/training_images"
self.test_images_dir = f"{self.data_dir}/{subject}/test_split/test_images"
self.training_fmri_dir = f"{self.data_dir}/{subject}/training_split/training_fmri"
self.roi_masks_dir = f"{self.data_dir}/{subject}/roi_masks"
self.submission_dir = f"algonauts_2023_challenge_submission"
# Load these as needed
self.train_img_list = None
self.test_img_list = None
self.train_cap_list = None
self.test_cap_list = None
self.lh_fmri = None
self.rh_fmri = None
self.lh_roi_masks = None
self.rh_roi_masks = None
self.roi_name_maps = None
self.lh_challenge_rois = None
self.rh_challenge_rois = None
self.train_img_dataloader = None
self.test_img_dataloader = None
self.train_txt_dataloader = None
self.test_txt_dataloader = None
def load_image_paths(self) -> None:
"""Loads the image paths from the training and test directories"""
self.train_img_list = glob.glob(f"{self.training_images_dir}/*.png")
self.train_img_list.sort()
self.test_img_list = glob.glob(f"{self.test_images_dir}/*.png")
self.test_img_list.sort()
# print(f"Training images: {len(self.train_img_list)}")
# print(f"Test images: {len(self.test_img_list)}")
def load_captions(self) -> None:
"""Loads and matches the captions from the csv file"""
if self.train_img_list is None:
self.load_image_paths()
train_cap_file = pd.read_csv('data/algonauts_2023_caption_data.csv')
img_match = [int(i[-9:-4]) for i in self.train_img_list]
self.train_cap_list = train_cap_file[(train_cap_file['subject'] == self.subject) & (train_cap_file['nsdId'].isin(img_match))]['caption'].tolist()
self.test_cap_list = train_cap_file[(train_cap_file['subject'] == self.subject) & (~train_cap_file['nsdId'].isin(img_match))]['caption'].tolist()
# print(f"Training captions: {len(self.train_cap_list)}")
# print(f"Test captions: {len(self.test_cap_list)}")
def load_neural_data(self) -> None:
"""Loads the neural data from the .npy files"""
self.lh_fmri = np.load(f"{self.training_fmri_dir}/lh_training_fmri.npy")
self.rh_fmri = np.load(f"{self.training_fmri_dir}/rh_training_fmri.npy")
# print(f"Left hemisphere neural data loaded. Shape: {self.lh_fmri.shape}")
# print(f"Right hemisphere neural data loaded. Shape: {self.rh_fmri.shape}")
def create_dataloaders(self, processor, batch_size) -> None:
"""Creates the dataloaders for the images and captions"""
if self.train_img_list is None:
self.load_image_paths()
if self.train_cap_list is None:
self.load_captions()
max_caption_len = processor(text=self.train_cap_list + self.test_cap_list, return_tensors="pt", padding=True)["input_ids"].shape[1]
train_txt_dataset = TextDataset(self.train_cap_list, max_caption_len, processor)
test_txt_dataset = TextDataset(self.test_cap_list, max_caption_len, processor)
train_img_dataset = ImageDataset(self.train_img_list, processor)
test_img_dataset = ImageDataset(self.test_img_list, processor)
self.train_img_dataloader = DataLoader(train_img_dataset, batch_size=batch_size, shuffle=False)
self.test_img_dataloader = DataLoader(test_img_dataset, batch_size=batch_size, shuffle=False)
self.train_txt_dataloader = DataLoader(train_txt_dataset, batch_size=batch_size, shuffle=False)
self.test_txt_dataloader = DataLoader(test_txt_dataset, batch_size=batch_size, shuffle=False)
print(f"Train image dataloader: {len(self.train_img_dataloader)} batches")
print(f"Test image dataloader: {len(self.test_img_dataloader)} batches")
print(f"Train caption dataloader: {len(self.train_txt_dataloader)} batches")
print(f"Test caption dataloader: {len(self.test_txt_dataloader)} batches")
def load_challenge_rois(self) -> None:
"""Loads the challenge rois from the .npy files"""
# Load the ROI classes mapping dictionaries
roi_mapping_files = ['mapping_prf-visualrois.npy', 'mapping_floc-bodies.npy',
'mapping_floc-faces.npy', 'mapping_floc-places.npy',
'mapping_floc-words.npy', 'mapping_streams.npy']
self.roi_name_maps = []
for r in roi_mapping_files:
self.roi_name_maps.append(np.load(f"{self.roi_masks_dir}/{r}", allow_pickle=True).item())
# Load the ROI brain surface maps
lh_challenge_roi_files = ['lh.prf-visualrois_challenge_space.npy',
'lh.floc-bodies_challenge_space.npy', 'lh.floc-faces_challenge_space.npy',
'lh.floc-places_challenge_space.npy', 'lh.floc-words_challenge_space.npy',
'lh.streams_challenge_space.npy']
rh_challenge_roi_files = ['rh.prf-visualrois_challenge_space.npy',
'rh.floc-bodies_challenge_space.npy', 'rh.floc-faces_challenge_space.npy',
'rh.floc-places_challenge_space.npy', 'rh.floc-words_challenge_space.npy',
'rh.streams_challenge_space.npy']
self.lh_challenge_rois = []
self.rh_challenge_rois = []
for r in range(len(lh_challenge_roi_files)):
self.lh_challenge_rois.append(np.load(f"{self.roi_masks_dir}/{lh_challenge_roi_files[r]}"))
self.rh_challenge_rois.append(np.load(f"{self.roi_masks_dir}/{rh_challenge_roi_files[r]}"))
def load_roi_masks(self, roi="V1v", hemisphere="lh"):
valid_roi = ["V1v", "V1d", "V2v", "V2d", "V3v",
"V3d", "hV4", "EBA", "FBA-1", "FBA-2",
"mTL-bodies", "OFA", "FFA-1", "FFA-2",
"mTL-faces", "aTL-faces", "OPA", "PPA",
"RSC", "OWFA", "VWFA-1", "VWFA-2", "mfs-words",
"mTL-words", "early", "midventral", "midlateral",
"midparietal", "ventral", "lateral", "parietal",
"all-vertices"]
valid_hemisphere = ["lh", "rh"]
assert roi in valid_roi, "Invalid ROI"
assert hemisphere in valid_hemisphere, "Invalid hemisphere"
# Define the ROI class based on the selected ROI
roi_class = ['prf-visualrois', 'floc-bodies', 'floc-faces', 'floc-places', 'floc-words', 'streams', 'all-vertices'][roi_class_index(roi)]
roi_class_dir = f"{hemisphere}.{roi_class}_fsaverage_space.npy"
roi_map_dir = f"mapping_{roi_class}.npy"
fsaverage_roi_class = np.load(f"{self.roi_masks_dir}/{roi_class_dir}")
roi_map = None
if roi != "all-vertices":
roi_map = np.load(f"{self.roi_masks_dir}/{roi_map_dir}", allow_pickle=True).item()
return fsaverage_roi_class, roi_map
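The caption matching in `load_captions` (and in `neural_rdm_shared` above) relies on the challenge's image filename convention: the NSD ID is the zero-padded five-digit number right before the `.png` extension, which the slice `[-9:-4]` recovers. A hypothetical example path illustrating the pattern:

```python
# Hypothetical path following the challenge's "train-XXXX_nsd-XXXXX.png" naming pattern
path = "training_images/train-0001_nsd-02951.png"
nsd_id = int(path[-9:-4])  # "02951" -> 2951
print(nsd_id)  # 2951
```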
The CLIPFeatureExtractor class is used to extract the hidden states from any CLIP model.
class CLIPFeatureExtractor():
"""Extracts the features from hidden states of a CLIP model."""
def __init__(
self,
idxs: list = [i for i in range(13)], # hidden layer indices to extract features from. Standard CLIP has an embedding layer and 12 transformer layers.
last_hidden_layer: bool = False, # whether to extract features from the last hidden layer
model: PreTrainedModel = None, # CLIP model
dataloader: DataLoader = None, # dataloader for batching
) -> None:
self.idxs = idxs
self.last_hidden_layer = last_hidden_layer
self.generate_feature_dict()
if self.last_hidden_layer:
self.idxs.append(13) # append an extra index so the zip() in extract_raw_features also covers the final layer
self.model = model
self.dataloader = dataloader
print(f"idxs: {self.idxs}")
print(f"feature dict keys: {self.feature_dict.keys()}")
def generate_feature_dict(self) -> None:
"""Generates a feature dict according to the idxs and last_hidden_layer attributes."""
feature_dict = {}
for idx in self.idxs:
if idx == 0:
feature_dict["Embedding Layer"] = None
else:
feature_dict[f"Transformer Layer {idx}"] = None
if self.last_hidden_layer:
feature_dict["Final Layer"] = None
self.feature_dict = feature_dict
def concat_features(self, features: dict) -> None:
"""Adds extracted features to the feature dict.
Args:
features: features extracted from the output of a CLIP model"""
keys = list(self.feature_dict.keys())
# check if feature_dict is empty
if self.feature_dict[keys[0]] is None:
self.feature_dict = features
else:
for key in keys:
self.feature_dict[key] = np.concatenate((self.feature_dict[key], features[key]), axis=0)
def extract_raw_features(self, output) -> None:
"""Extracts features from the hidden states of a CLIP model and concatenates them to the feature_dict.
Args:
output: output of a CLIP model
"""
features = {}
for idx, key in zip(self.idxs, self.feature_dict.keys()):
if key == "Final Layer":
features[key] = output.last_hidden_state.cpu().detach().numpy()
else:
features[key] = output.hidden_states[idx].cpu().detach().numpy()
self.concat_features(features)
def extract_raw_features_from_model(self) -> None:
"""Runs the CLIP model on the dataloader and extracts features from the hidden states."""
self.model = self.model.to(device)
with torch.no_grad():
for batch in tqdm(self.dataloader):
batch = batch.to(device)
output = self.model(batch, output_hidden_states=True)
self.extract_raw_features(output)
batch = None # clear batch from memory
output = None # clear output from memory
self.model = self.model.to("cpu")
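The accumulation logic in `concat_features` can be sketched standalone with NumPy only; the shapes below are hypothetical (three batches of 4 inputs, 77 tokens, 512 hidden dimensions, matching the text encoder width printed above):

```python
import numpy as np

def concat_features(feature_dict, features):
    """Accumulates per-batch feature arrays along the trial axis (axis 0)."""
    first_key = next(iter(feature_dict))
    if feature_dict[first_key] is None:  # first batch: adopt the batch dict as-is
        return features
    return {k: np.concatenate((feature_dict[k], features[k]), axis=0) for k in feature_dict}

layers = ["Embedding Layer", "Transformer Layer 1"]
feature_dict = {k: None for k in layers}
for _ in range(3):  # three batches of 4 inputs each
    batch = {k: np.zeros((4, 77, 512)) for k in layers}
    feature_dict = concat_features(feature_dict, batch)
print(feature_dict["Embedding Layer"].shape)  # (12, 77, 512)
```

Concatenating along axis 0 after each batch means the final arrays are indexed by trial, which is what the downstream regression and RDM code expects.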
The KFoldProcedure class is used to define a procedure for each fold of the k-fold validation. It can be supplied to a KFold class which executes its run() function on all folds.
class KFoldProcedure:
"""This class is used to define a procedure that is run on each fold of a k-fold cross validation."""
def __init__(self) -> None:
assert isinstance(self.model_name, str) and len(self.model_name) > 0, "Please define a model name as part of the KFold Procedure."
assert isinstance(self.description, str) and len(self.description) > 0, "Please define a description as part of the KFold Procedure."
def prepare(self) -> None:
"""Operations that should be executed before the fold loop"""
raise NotImplementedError
def run(self, train_idxs: np.ndarray, val_idxs: np.ndarray) -> Dict[str, Dict[str, np.ndarray]]:
"""This should return a dict of correlations.
dict format: {"layer": {"lh": np.ndarray, "rh": np.ndarray}}"""
raise NotImplementedError
def return_idxs(self):
"""Returns idxs to create folds in the KFold class."""
raise NotImplementedError
def return_roi_names(self) -> List[str]:
"""Required for the plot function in the KFold class."""
return self.roi_names
def return_model_name_and_description(self) -> Tuple[str, str]:
return self.model_name, self.description
def calculate_correlations(self, lh_pred, rh_pred, lh_fmri, rh_fmri) -> Tuple[np.ndarray, np.ndarray]:
"""Calculate correlation between prediction and fmri activation"""
lh_correlation = np.zeros(lh_pred.shape[1])
for v in tqdm(range(lh_pred.shape[1])):
lh_correlation[v] = corr(lh_pred[:,v], lh_fmri[:,v])[0]
# Right hemisphere
rh_correlation = np.zeros(rh_pred.shape[1])
for v in tqdm(range(rh_pred.shape[1])):
rh_correlation[v] = corr(rh_pred[:,v], rh_fmri[:,v])[0]
return lh_correlation, rh_correlation
def calculate_median_correlations(self, lh_correlation, rh_correlation) -> Tuple[np.ndarray, np.ndarray]:
"""Calculate median correlation for each ROI."""
# Select the correlation results for the vertices of each ROI
lh_challenge_rois = self.lh_challenge_rois
rh_challenge_rois = self.rh_challenge_rois
self.roi_names = []
lh_roi_correlation = []
rh_roi_correlation = []
for r1 in range(len(lh_challenge_rois)):
for r2 in self.roi_name_maps[r1].items():
if r2[0] != 0: # zeros indicate vertices falling outside the ROI of interest
self.roi_names.append(r2[1])
lh_roi_idx = np.where(lh_challenge_rois[r1] == r2[0])[0]
rh_roi_idx = np.where(rh_challenge_rois[r1] == r2[0])[0]
lh_roi_correlation.append(lh_correlation[lh_roi_idx])
rh_roi_correlation.append(rh_correlation[rh_roi_idx])
self.roi_names.append('All vertices')
lh_roi_correlation.append(lh_correlation)
rh_roi_correlation.append(rh_correlation)
lh_median_roi_correlation = [np.median(lh_roi_correlation[r])
for r in range(len(lh_roi_correlation))]
rh_median_roi_correlation = [np.median(rh_roi_correlation[r])
for r in range(len(rh_roi_correlation))]
return lh_median_roi_correlation, rh_median_roi_correlation
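The per-voxel loop in `calculate_correlations` reduces to the following (synthetic arrays; `corr` in the class is assumed here to be `scipy.stats.pearsonr`):

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(1)
fmri = rng.normal(size=(50, 3))   # 50 trials x 3 voxels of synthetic activation
pred = 2.0 * fmri + 1.0           # a perfect linear prediction of each voxel
# One Pearson correlation per voxel (column), as in calculate_correlations
corrs = np.array([pearsonr(pred[:, v], fmri[:, v])[0] for v in range(fmri.shape[1])])
```

Since Pearson correlation is invariant to linear rescaling, a prediction that gets each voxel right up to scale and offset still scores 1.0.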
class KFold:
"""Run a k-fold cross validation with a given procedure."""
def __init__(self, folds: int = 8, seed: int = 5, procedure: KFoldProcedure = None) -> None:
assert folds > 1, "folds must be greater than 1"
assert seed > 0, "seed must be greater than 0"
assert isinstance(folds, int), "folds must be an integer"
assert isinstance(seed, int), "seed must be an integer"
#assert isinstance(procedure, KFoldProcedure), "procedure must be an instance of KFoldProcedure"
self.folds = folds
self.seed = seed
self.procedure = procedure
self.fold_correlations = {}
self.mean_correlations = None
def run(self) -> None:
"""Runs the procedure on each fold and accesses the correlations."""
self.procedure.prepare()
# Create k folds
fold_idxs = self.procedure.return_idxs()
np.random.seed(self.seed)
np.random.shuffle(fold_idxs)
self.fold_idxs = np.array_split(fold_idxs, self.folds)
for fold in range(self.folds):
# Select validation and train set
val_idxs = self.fold_idxs[fold]
train_idxs = np.concatenate([self.fold_idxs[j] for j in range(self.folds) if j != fold])
# Info for current fold
print(f"#############################################")
print(f"# Fold: {fold+1}/ {self.folds}")
print(f"# Train size: {len(train_idxs)}")
print(f"# Validation size: {len(val_idxs)}")
print(f"#############################################")
# Run procedure
self.fold_correlations[fold] = self.procedure.run(train_idxs, val_idxs)
# Get model name and description
self.model_name, self.description = self.procedure.return_model_name_and_description()
self.roi_names = self.procedure.return_roi_names()
self.calculate_mean_across_folds()
self.mean_correlations_to_csv()
def calculate_mean_across_folds(self):
"""Calculates the mean across folds for each layer"""
self.mean_correlations = {}
for layer in self.fold_correlations[0].keys():
self.mean_correlations[layer] = {}
for hemi in self.fold_correlations[0][layer].keys():
self.mean_correlations[layer][hemi] = np.nanmean([self.fold_correlations[fold][layer][hemi] for fold in range(self.folds)], axis=0)
def mean_correlations_to_csv(self) -> None:
# Collect rows first and build the DataFrame in one call (DataFrame.append was removed in pandas 2.0)
rows = []
for layer in self.mean_correlations.keys():
for hemisphere in self.mean_correlations[layer].keys():
for i in range(len(self.roi_names)):
rows.append({"model": self.model_name, "layer": layer, "hemisphere": hemisphere, "roi": self.roi_names[i], "correlation": self.mean_correlations[layer][hemisphere][i]})
df = pd.DataFrame(rows, columns=["model", "layer", "hemisphere", "roi", "correlation"])
validations = glob.glob(f"validations/validation*")
if len(validations) == 0:
# create first validation folder
folder_name = "validation001"
os.makedirs(f"validations/{folder_name}")
else:
# create next validation folder
last_validation = sorted(validations)[-1]
last_validation_number = int(last_validation.split("/")[-1].split("validation")[-1])
next_validation_number = last_validation_number + 1
folder_name = f"validation{str(next_validation_number).zfill(3)}"
os.makedirs(f"validations/{folder_name}")
# Write text file with model description
with open(f"validations/{folder_name}/info.txt", "w") as f:
f.write(self.description)
# Save dataframe
df.to_csv(f"validations/{folder_name}/results.csv", index=False)
Additionally, we define a function to plot the validation results.
def plot_kfold_result(validation = "001"):
"""Plots the validation results from the csv file in the given validation folder."""
folder = f"validations/validation{validation}"
with open(f"{folder}/info.txt", 'r') as f:
info = f.read()
df = pd.read_csv(f"{folder}/results.csv")
# drop model column
df = df.drop("model", axis=1)
# Define color palette and assign colors to layers
palette = sns.color_palette("colorblind", 14)
layer_colors = {layer: palette[i] for i, layer in enumerate(df.layer.unique())}
# Split data into left and right hemispheres
left_df = df[df['hemisphere'] == 'lh']
right_df = df[df['hemisphere'] == 'rh']
# Create bar plots
fig, axes = plt.subplots(2, 1, figsize=(30, 10))
fig.suptitle(info)
plt.subplots_adjust(hspace=0.5)
bar_width = 0.05
# Plot left hemisphere data
for i, layer in enumerate(df.layer.unique()):
layer_data = left_df[left_df['layer'] == layer]
x = np.arange(len(layer_data['roi']))
# center bars around xtick
axes[0].bar(x - len(df.layer.unique())/2 * 0.05 + i * 0.05, layer_data['correlation'], width=bar_width, label=layer, color=layer_colors[layer])
axes[0].margins(x=0.01) # reduce white space before first x tick
axes[0].set_xticks([i for i in range(len(layer_data['roi']))])
axes[0].set_xticklabels([roi for roi in layer_data['roi']])
axes[0].set_title('Left Hemisphere')
axes[0].set_xlabel('ROI')
axes[0].tick_params(axis='x', labelrotation=45)
axes[0].set_ylabel('Correlation')
axes[0].legend(loc='upper center', bbox_to_anchor=(0.5, 1), ncol=5)
axes[0].set_ylim([0, 1])
left_means = left_df.groupby('layer')['correlation'].mean().sort_values(ascending=False)
row_colors = [layer_colors[layer] for layer in left_means.index]
axes[0].table(cellText=left_means.round(3).values.reshape(-1, 1), rowLabels=left_means.index, colLabels=["Mean Correlation"], rowColours=row_colors, bbox=[0.95, 0.5, 0.05, 0.5])
# Plot right hemisphere data
for i, layer in enumerate(df.layer.unique()):
layer_data = right_df[right_df['layer'] == layer]
x = np.arange(len(layer_data['roi']))
axes[1].bar(x - len(df.layer.unique())/2 * 0.05 + i * 0.05, layer_data['correlation'], width=bar_width, label=layer, color=layer_colors[layer])
axes[1].margins(x=0.01)
axes[1].set_xticks([i for i in range(len(layer_data['roi']))])
axes[1].set_xticklabels([roi for roi in layer_data['roi']])
axes[1].set_title('Right Hemisphere')
axes[1].set_xlabel('ROI')
axes[1].tick_params(axis='x', labelrotation=45)
axes[1].set_ylabel('Correlation')
axes[1].legend(loc='upper center', bbox_to_anchor=(0.5, 1), ncol=5)
axes[1].set_ylim([0, 1])
right_means = right_df.groupby('layer')['correlation'].mean().sort_values(ascending=False)
row_colors = [layer_colors[layer] for layer in right_means.index]
axes[1].table(cellText=right_means.round(3).values.reshape(-1, 1), rowLabels=right_means.index, colLabels=["Mean Correlation"], rowColours=row_colors, bbox=[0.95, 0.5, 0.05, 0.5])
plt.show()
Which subject would you like to investigate?
# Select the subject
subject = "subj01" # ["subj01", "subj02", "subj03", "subj04", "subj05", "subj06", "subj07", "subj08"]
dirs = Subject(subject)
First we load the fMRI training split data of the selected subject. The fMRI data consists of two '.npy' files:
- lh_training_fmri.npy: the left hemisphere (LH) fMRI data.
- rh_training_fmri.npy: the right hemisphere (RH) fMRI data.
Both files are 2-dimensional arrays with training stimulus images as rows and fMRI vertices as columns.
# Load Neural data
dirs.load_neural_data()
lh_fmri, rh_fmri = dirs.lh_fmri, dirs.rh_fmri
print(f"Left hemisphere neural data loaded. Shape: {lh_fmri.shape}")
print(f"Right hemisphere neural data loaded. Shape: {rh_fmri.shape}")
Left hemisphere neural data loaded. Shape: (9841, 19004) Right hemisphere neural data loaded. Shape: (9841, 20544)
Let's look at the neural RDMs for both hemispheres of subject 1.
# Store the neural RDMs
neural_rdms = {'Left Hemisphere RDM': euclidean_distances(lh_fmri),
'Right Hemisphere RDM': euclidean_distances(rh_fmri)}
plot_rdm(neural_rdms)
Due to the size of the RDMs, clusters are difficult to spot. We therefore zoom into a section to identify some clusters. (Feel free to play around with the zoom.)
plot_rdm(neural_rdms, zoom=[1200, 1500])
The above RDM plots show (dis-)similarity clustering in the neural activity in response to viewing different natural scene stimuli, for both the left and right hemispheres. This means that some trial stimuli elicit more (dis-)similar brain activity than others. Intuitively this makes sense, as the COCO dataset contains a wide variety of stimuli that vary in similarity, not unlike the clustering in the study of Jozwik et al. (2016). In other words, this clustering in the RDMs of COCO images occurs because images with similar content or features may evoke similar patterns of activation in the brain or in a neural network. For example, if we compute the RDMs of the neural responses to images with similar backgrounds or scenes, such as images of beaches or mountains, the RDMs for these images may cluster together because the neural representations of the scene images share common features such as colors, textures, and shapes that are unique to that scene. However, in this study we did not examine which COCO images showed high (dis-)similarity; presumably the clustering is comparable to that in Jozwik et al. (2016).
Next, we evaluate what model performance we could expect, not only compared to other groups, but given the data and its associated signal and noise levels. To that end, we investigate the noise ceiling of our data, which indicates how much variance our model can at best explain and how much noise the data contains. As subjects viewed different pictures for the majority of the experiment, averaging across their full neural RDMs is difficult, especially as the number of trials differs across subjects. We therefore extracted the images that were shared across subjects and calculated the noise ceiling on the corresponding subset of neural observations.
subs = [f"subj0{i}" for i in range(1, 9)]
neural_rdms_shared = neural_rdm_shared(subs=subs)
# Calculate Noise Ceiling
NCs = np.zeros((len(subs), 2))
rdms_subjects = np.mean(neural_rdms_shared, axis=1)
rdms_average_upp = np.mean(rdms_subjects, axis=0)
rdms_average_low = [np.mean(rdms_subjects[np.arange(len(subs)) != i,:,:], axis=0) for i in range(len(subs))]
# Computing the Individual Lower Noise Ceiling Bound
NCs[:, 0] = [spearmanr(squareform(rdms_subjects[i,:,:]), squareform(rdms_average_low[i]))[0] for i in range(len(subs))]
# Computing the Individual Upper Noise Ceiling Bound
NCs[:, 1] = [spearmanr(squareform(rdms_subjects[i,:,:]), squareform(rdms_average_upp))[0] for i in range(len(subs))]
# Compute Average Upper & Lower Noise Ceiling Bound
mean_NC = np.mean(NCs, axis=0)
print(f'The Noise Ceiling for the shared images has a Lower Bound of {mean_NC[0]:.2f} and an Upper Bound of {mean_NC[1]:.2f}')
(noiseCeiling := pd.DataFrame({"LowerBound": [mean_NC[0]], "UpperBound": [mean_NC[1]]}))
The Noise Ceiling for the shared images has a Lower Bound of 0.53 and an Upper Bound of 0.65
|   | LowerBound | UpperBound |
|---|---|---|
| 0 | 0.5344 | 0.654185 |
# Plot Noise Ceiling
plt.figure(figsize=(10, 5))
# Extract the scalar bounds from the one-row DataFrame
ub = noiseCeiling['UpperBound'].iloc[0]
lb = noiseCeiling['LowerBound'].iloc[0]
plt.plot([0, 1], [ub, ub], 'r--', label='Upper bound')
plt.plot([0, 1], [lb, lb], 'b--', label='Lower bound')
plt.fill_between([0, 1], ub, lb, color='grey', alpha=0.5)
plt.fill_between([0, 1], lb, 0, color='khaki', alpha=0.2)
plt.ylabel('Spearman correlation')
plt.ylim([0,1])
plt.xticks([])
plt.title('Noise ceiling')
plt.legend()
plt.margins(x=0.01)
plt.show()
The noise ceiling indicates that a substantial amount of variance can be explained and that the portion not reliably shared across subjects (i.e., noise) is rather small. This gives us a good indication of what to expect when evaluating the later model fits.
All images consist of natural scenes coming from the COCO dataset. The images are divided into a training and a test split (corresponding to the fMRI training and test data splits). The number of training and test images varies between subjects.
# Load image paths
dirs.load_image_paths()
train_img_list, test_img_list = dirs.train_img_list, dirs.test_img_list
print(f"Training images: {len(train_img_list)}")
print(f"Test images: {len(test_img_list)}")
Training images: 9841
Test images: 159
All NSD images also have a caption associated with them that can be traced back to their origin in the COCO dataset (shown above each image in the dataset). First, a mapping between the new NSD IDs and the original COCO IDs was retrieved, available in the NSD AWS repository under nsd_stim_info_merged.csv. Next, JSON files containing the captions and corresponding IDs were downloaded from the official COCO website. Lastly, captions were matched based on the old and new IDs.
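The matching step can be sketched as follows. This is a minimal illustration with toy data, assuming the stimulus-info table provides `nsdId` and `cocoId` columns (as in nsd_stim_info_merged.csv) and that the COCO captions JSON provides an `annotations` list with `image_id` and `caption` keys; the IDs below are hypothetical.

```python
import pandas as pd

def match_nsd_captions(stim_info: pd.DataFrame, coco_annotations: list) -> dict:
    """Map each NSD ID to its COCO captions via the shared COCO ID."""
    # Group captions by COCO image ID (one image can have several captions)
    coco_to_caps = {}
    for ann in coco_annotations:
        coco_to_caps.setdefault(ann["image_id"], []).append(ann["caption"])
    # Look up each NSD image's captions through its COCO ID
    return {row.nsdId: coco_to_caps.get(row.cocoId, [])
            for row in stim_info.itertuples()}

# Toy example (hypothetical IDs, not the real NSD/COCO mapping)
stim = pd.DataFrame({"nsdId": [0, 1], "cocoId": [139, 285]})
anns = [{"image_id": 139, "caption": "A beach at sunset."},
        {"image_id": 285, "caption": "A mountain trail."}]
print(match_nsd_captions(stim, anns))
```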
# Load captions
dirs.load_captions()
train_cap_list, test_cap_list = dirs.train_cap_list, dirs.test_cap_list
max_caption_len = processor(text=train_cap_list + test_cap_list, return_tensors="pt", padding=True)["input_ids"].shape[1]
print(f"Training captions: {len(train_cap_list)}")
print(f"Test captions: {len(test_cap_list)}")
print(f"Max caption length: {max_caption_len}")
Training captions: 9841 Test captions: 159 Max caption length: 46
The visual cortex is divided into multiple areas having different functional properties, referred to as regions-of-interest (ROIs). Along with the fMRI data, ROI indices are provided for selecting vertices belonging to specific visual ROIs.
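The vertex selection works by boolean masking: the ROI index array assigns each vertex a label, and comparing it against a target label selects the matching fMRI columns. A minimal sketch with hypothetical data (the real arrays come from the challenge ROI files):

```python
import numpy as np

# Hypothetical setup: fMRI responses for 3 images over 10 vertices,
# and an ROI index array assigning each vertex an ROI label.
fmri = np.arange(30, dtype=float).reshape(3, 10)
roi_indices = np.array([0, 1, 1, 0, 2, 2, 1, 0, 2, 1])  # e.g. label 1 == "V1v"

# Select the columns (vertices) belonging to ROI label 1
v1v_responses = fmri[:, roi_indices == 1]
print(v1v_responses.shape)  # (3, 4): 3 images x 4 V1v vertices
```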
Following is the list of ROIs (ROI class file names in parentheses):
Next, we plot the vertices belonging to a specified fMRI ROI, together with an example stimulus image and its caption.
img = 354 #@param
hemisphere = "left" # ['left', 'right']
roi = "EBA" # ["all-vertices", "V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "EBA", "FBA-1", "FBA-2", "mTL-bodies", "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces", "OPA", "PPA", "RSC", "OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words", "early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]
# Load the image
train_img = Image.open(train_img_list[img]).convert('RGB')
# Plot the image
plt.figure()
plt.axis('off')
plt.imshow(train_img)
plt.title('Training image: ' + str(img+1) + '\n' + train_cap_list[img]);
fsaverage_roi_class, roi_map = dirs.load_roi_masks(roi, "lh" if hemisphere == "left" else "rh")
if roi != "all-vertices":
    dirs.load_challenge_rois()
    challenge_roi_class = dirs.lh_challenge_rois if hemisphere == "left" else dirs.rh_challenge_rois
    roi_mapping = list(roi_map.keys())[list(roi_map.values()).index(roi)]
    fsaverage_roi = np.asarray(fsaverage_roi_class == roi_mapping, dtype=int)
    challenge_roi = np.asarray(challenge_roi_class[roi_class_index(roi)] == roi_mapping, dtype=int)
    # Map the fMRI data onto the brain surface map
    fsaverage_response = np.zeros(len(fsaverage_roi))
    if hemisphere == 'left':
        fsaverage_response[np.where(fsaverage_roi)[0]] = lh_fmri[img, np.where(challenge_roi)[0]]
    elif hemisphere == 'right':
        fsaverage_response[np.where(fsaverage_roi)[0]] = rh_fmri[img, np.where(challenge_roi)[0]]
else:
    fsaverage_response = np.zeros(len(fsaverage_roi_class))
    if hemisphere == 'left':
        fsaverage_response[np.where(fsaverage_roi_class)[0]] = lh_fmri[img]
    elif hemisphere == 'right':
        fsaverage_response[np.where(fsaverage_roi_class)[0]] = rh_fmri[img]
# Create the interactive brain surface map
fsaverage = datasets.fetch_surf_fsaverage('fsaverage')
view = plotting.view_surf(
surf_mesh=fsaverage['infl_'+hemisphere],
surf_map=fsaverage_response,
bg_map=fsaverage['sulc_'+hemisphere],
threshold=1e-14,
cmap='cold_hot',
colorbar=True,
title=roi+', '+hemisphere+' hemisphere'
)
view